"""
# Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import unittest

import paddle

from fastdeploy.model_executor.layers.moe.moe import get_moe_scores


class TestMoeRouting(unittest.TestCase):
    """Check ``get_moe_scores`` against a plain-Paddle grouped top-k reference.

    ``native_group_topk`` implements sigmoid-gated, bias-corrected, group-limited
    routing with ordinary tensor ops. ``test_group_topk`` compares its selected
    expert ids and routing weights with the output of ``get_moe_scores`` for
    several expert/group configurations and token counts.
    """

    def setUp(self):
        # Fix the RNG so every run routes the same random logits.
        paddle.seed(2024)
        # Log device and Paddle build info to make mismatch reports easier to triage.
        print(paddle.device.cuda.get_device_properties())
        print(paddle.__git_commit__)

    def native_group_topk(
        self,
        gating_output: paddle.Tensor,
        topk: int,
        renormalize: bool,
        num_expert_group: int,
        topk_group: int,
        routed_scaling_factor: float,
        e_score_correction_bias: paddle.Tensor,
    ):
        """Reference group-limited top-k routing written with plain Paddle ops.

        Args:
            gating_output: 2-D router logits, ``[num_tokens, num_experts]``.
            topk: Number of experts selected per token.
            renormalize: If True, rescale each token's selected weights to sum to 1.
            num_expert_group: Number of equal-sized expert groups. ``num_experts``
                must be divisible by it, and each group must hold at least two
                experts (groups are ranked by their top-2 scores).
            topk_group: Number of groups kept per token.
            routed_scaling_factor: Multiplier applied to the final weights
                (skipped when exactly 1.0).
            e_score_correction_bias: Bias added to the sigmoid scores for expert
                *selection only*; shape ``[num_experts]`` or ``[1, num_experts]``.

        Returns:
            ``(topk_weights, topk_ids)``, each ``[num_tokens, topk]``. Weights are
            gathered from the un-biased sigmoid scores; ids are ordered by
            descending biased score (``paddle.topk`` default).
        """
        original_scores = paddle.nn.functional.sigmoid(gating_output)
        if len(e_score_correction_bias.shape) == 1:
            # Promote a 1-D bias to [1, num_experts] so it broadcasts over tokens.
            e_score_correction_bias = e_score_correction_bias.unsqueeze(0)
        # The bias only influences which experts are chosen, not their weights.
        scores = original_scores + e_score_correction_bias

        num_token, n_experts = scores.shape
        # Rank each group by the sum of its two highest biased scores.
        group_scores = scores.reshape([num_token, num_expert_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
        group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1]  # [n, top_k_group]
        group_mask = paddle.zeros_like(group_scores)  # [n, n_group]
        group_mask.put_along_axis_(group_idx, 1.0, axis=-1)  # [n, n_group]
        # Broadcast the per-group keep/drop flag to every expert inside that group.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand([num_token, num_expert_group, n_experts // num_expert_group])
            .reshape([num_token, -1])
        )
        # Experts in discarded groups get -inf so they can never win the final top-k.
        tmp_scores = scores.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))

        topk_ids = paddle.topk(tmp_scores, topk, axis=1)[1]
        topk_weights = paddle.take_along_axis(original_scores, topk_ids, axis=1)

        if renormalize:
            topk_weights = topk_weights / paddle.sum(topk_weights, axis=1, keepdim=True)

        if routed_scaling_factor != 1.0:
            topk_weights = topk_weights * routed_scaling_factor

        return topk_weights, topk_ids

    def test_group_topk(self):
        """Compare ``get_moe_scores`` with ``native_group_topk`` over several configs."""
        renormalize = True

        test_cases = [
            # (num_experts, n_group, topk_group, top_k, routed_scaling_factor)
            (128, 1, 1, 8, 1.0),  # glm45-air
            (256, 8, 4, 8, 2.5),  # deepseek
        ]

        for case_tuple in test_cases:
            num_experts, n_group, topk_group, top_k, routed_scaling_factor = case_tuple
            for num_tokens in [1, 32, 64, 128]:
                # subTest tags any failure with the offending parameters.
                with self.subTest(case=case_tuple, num_tokens=num_tokens):
                    gating_output = paddle.rand([num_tokens, num_experts])
                    e_score_correction_bias = paddle.rand([1, num_experts])

                    ref_topk_values, ref_topk_idx = self.native_group_topk(
                        gating_output=gating_output,
                        topk=top_k,
                        renormalize=renormalize,
                        num_expert_group=n_group,
                        topk_group=topk_group,
                        routed_scaling_factor=routed_scaling_factor,
                        e_score_correction_bias=e_score_correction_bias,
                    )

                    # The first returned value is not compared by this test.
                    _, topk_values, topk_idx = get_moe_scores(
                        gating_output=gating_output,
                        n_group=n_group,
                        topk_group=topk_group,
                        top_k=top_k,
                        routed_scaling_factor=routed_scaling_factor,
                        e_score_correction_bias=e_score_correction_bias,
                        renormalize=renormalize,
                    )

                    equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
                    # Expert ids must match exactly. equal_all expresses that directly
                    # and also yields False when the shapes differ.
                    # NOTE(review): the comparison is position-wise, so it assumes
                    # get_moe_scores orders ids like paddle.topk (descending) — confirm.
                    equal_topk_ids = paddle.equal_all(topk_idx.cast("int64"), ref_topk_idx.cast("int64")).item()
                    print(
                        f"Test Case[{case_tuple}], num_tokens = {num_tokens}, equal_topk_value: {equal_topk_value}, equal_topk_ids: {equal_topk_ids}"
                    )
                    if not equal_topk_value:
                        print(f"ref_topk_values = {ref_topk_values}")
                        print(f"topk_values = {topk_values}")
                    if not equal_topk_ids:
                        print(f"ref_topk_idx = {ref_topk_idx}")
                        print(f"topk_idx = {topk_idx}")
                    # Use unittest assertions: bare `assert` is stripped under `python -O`.
                    self.assertTrue(
                        equal_topk_value,
                        f"topk values mismatch for case {case_tuple}, num_tokens={num_tokens}",
                    )
                    self.assertTrue(
                        equal_topk_ids,
                        f"topk ids mismatch for case {case_tuple}, num_tokens={num_tokens}",
                    )


if __name__ == "__main__":
    # Run the tests defined in this module when the file is executed as a script.
    unittest.main()
